# Distributed training example with managed spot job + automatic
# checkpoint/restore.
#
# Creates a shared S3 mount across workers that is used to read and write
# checkpoints. Checkpoints are saved every 10 epochs.
#
# Adapted from romil/michaelzhiluo/michaelvll's examples.

# Name of this task.
name: resnet

# Hardware requested for every node of the job.
resources:
    # Where to provision and which accelerator to attach.
    cloud: aws
    accelerators: V100
    # Run on spot capacity.
    use_spot: true
    # Recovery strategy used for the spot job.
    # NOTE(review): the exact behavior of FAILOVER is defined by the launcher,
    # not by this file -- confirm against its documentation.
    spot_recovery: FAILOVER

# Number of nodes to launch. The run section does not hard-code this value:
# it derives --nnodes from the line count of $SKYPILOT_NODE_IPS (presumably
# one IP per node -- confirm).
num_nodes: 2

file_mounts:
    # Bucket attached at /checkpoint in MOUNT mode. The setup and run sections
    # create and use directories under this path for checkpoints and wandb.
    /checkpoint:
        # NOTE: Fill in your bucket name for checkpoints
        name:
        mode: MOUNT

    # Training code taken from the local source directory and made available
    # at ~/resnet_ddp in COPY mode.
    ~/resnet_ddp:
        # NOTE: Fill in your bucket name for the code
        name:
        source: ./examples/spot/resnet_ddp
        mode: COPY
        # NOTE(review): presumably means the code bucket is not kept after the
        # job ends -- confirm.
        persistent: false

# One-time node preparation: store the wandb key, install Python dependencies,
# fetch CIFAR-10, and create the output directories on the /checkpoint mount.
setup: |
    # Fill in your wandb key: copy from https://wandb.ai/authorize
    # The line is single-quoted so the shell writes it verbatim and never
    # treats the bracketed placeholder as a glob pattern.
    echo 'export WANDB_API_KEY=[YOUR-WANDB-API-KEY]' >> ~/.bashrc

    pip3 install --upgrade pip
    cd ~/resnet_ddp && pip3 install -r requirements.txt

    # Download and unpack CIFAR-10 into ~/resnet_ddp/data. Every step is
    # chained with && so extraction only runs inside data/ and only after the
    # download succeeded.
    mkdir -p data saved_models && cd data && \
    wget -c --quiet https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz && \
    tar -xvzf cifar-10-python.tar.gz

    # Output directories on the mounted bucket: the first is the --model_dir
    # passed in the run section; the second sits under the --wandb_dir passed
    # there (presumably where wandb writes its files -- confirm).
    mkdir -p /checkpoint/torch_ddp_resnet/
    mkdir -p /checkpoint/wandb

# Per-node training command: launches resnet_ddp.py through
# torch.distributed.launch with one process per node.
run: |
    cd ~/resnet_ddp

    # Change this id for every new run!
    run_id="resnet-run-1"

    # $SKYPILOT_NODE_IPS is read line by line: the number of lines becomes the
    # node count and the first line becomes the master address.
    node_count=$(echo "$SKYPILOT_NODE_IPS" | wc -l)
    head_ip=$(echo "$SKYPILOT_NODE_IPS" | head -n1)

    # Checkpoints and wandb files are directed to the /checkpoint mount;
    # --resume is passed on every launch (how resnet_ddp.py handles it is not
    # visible in this file).
    python3 -m torch.distributed.launch \
        --nproc_per_node=1 \
        --nnodes=$node_count \
        --node_rank=${SKYPILOT_NODE_RANK} \
        --master_addr=$head_ip \
        --master_port=8008 \
        resnet_ddp.py \
        --num_epochs 100000 \
        --model_dir /checkpoint/torch_ddp_resnet/ \
        --resume \
        --model_filename resnet_distributed-with-epochs.pth \
        --run_id $run_id \
        --wandb_dir /checkpoint/
